# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image tests."""
import io
import unittest
from unittest import mock

import langfun.core as lf
from langfun.core.modalities import image as image_lib
from langfun.core.modalities import mime as mime_lib
import PIL.Image as pil_image
import pyglove as pg


# PNG fixture shared by the tests below. Its IHDR chunk encodes a 24x24 image
# with bit depth 4 and colour type 3 (palette-based), followed by PLTE, tRNS,
# IDAT and IEND chunks. `mock_request` serves these bytes as fetched URI
# content, and `test_image_size` expects the decoded size to be (24, 24).
image_content = (
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
)


def mock_request(*unused_args, **unused_kwargs):
  """Fake replacement for `requests.get` that ignores how it is called.

  Returns:
    A `pg.Dict` whose `content` field holds the `image_content` PNG bytes.
  """
  fake_response = pg.Dict(content=image_content)
  return fake_response


class ImageTest(unittest.TestCase):
  """Tests for `image_lib.Image` construction, conversion and compatibility."""

  def _make_gif_image(self):
    """Returns a 1x1 palette-mode GIF wrapped as an `image_lib.Image`.

    Shared by the GIF compatibility tests so the fixture is built in one
    place. It also checks that the bytes were detected as `image/gif`.
    """
    buf = io.BytesIO()
    pil_image.new('P', (1, 1)).save(buf, format='GIF')
    gif_image = image_lib.Image.from_bytes(buf.getvalue())
    self.assertEqual(gif_image.mime_type, 'image/gif')
    return gif_image

  def test_from_bytes(self):
    image = image_lib.Image.from_bytes(image_content)
    self.assertEqual(image.image_format, 'png')
    self.assertIn('data:image/png;base64,', image._raw_html())
    self.assertEqual(image.to_bytes(), image_content)
    with self.assertRaisesRegex(
        lf.ModalityError, '.* cannot be converted to text'
    ):
      image.to_text()

  def test_from_bytes_invalid(self):
    image = image_lib.Image.from_bytes(b'bad')
    # Construction succeeds; the MIME check only fails on format access.
    with self.assertRaisesRegex(ValueError, 'Expected MIME type'):
      _ = image.image_format

  def test_from_bytes_base_cls(self):
    self.assertIsInstance(
        mime_lib.Mime.from_bytes(image_content), image_lib.Image
    )

  def test_from_uri(self):
    # `from_uri` is called outside the patch, so content is expected to be
    # fetched lazily when the accessors below are used.
    image = image_lib.Image.from_uri('http://mock/web/a.png')
    with mock.patch('requests.get') as mock_requests_get:
      mock_requests_get.side_effect = mock_request
      self.assertEqual(image.image_format, 'png')
      self.assertEqual(
          image._raw_html(),
          '<img src="http://mock/web/a.png">'
      )
      self.assertEqual(image.to_bytes(), image_content)

  def test_from_uri_base_cls(self):
    with mock.patch('requests.get') as mock_requests_get:
      mock_requests_get.side_effect = mock_request
      image = mime_lib.Mime.from_uri('http://mock/web/a.png')
      self.assertIsInstance(image, image_lib.Image)
      self.assertEqual(image.image_format, 'png')
      self.assertEqual(
          image._raw_html(),
          '<img src="http://mock/web/a.png">'
      )
      self.assertEqual(image.to_bytes(), image_content)

  def test_image_size(self):
    image = image_lib.Image.from_uri('http://mock/web/a.png')
    with mock.patch('requests.get') as mock_requests_get:
      mock_requests_get.side_effect = mock_request
      self.assertEqual(image.size, (24, 24))

  def test_to_pil_image(self):
    image = image_lib.Image.from_uri('http://mock/web/a.png')
    with mock.patch('requests.get') as mock_requests_get:
      mock_requests_get.side_effect = mock_request
      self.assertIsInstance(image.to_pil_image(), pil_image.Image)

  def test_from_pil_image(self):
    image = pil_image.open(io.BytesIO(image_content))
    self.assertIsInstance(
        image_lib.Image.from_pil_image(image), image_lib.Image
    )

  def test_from_pil_image_os_error(self):
    img = pil_image.open(io.BytesIO(image_content))
    with mock.patch.object(img, 'save') as mock_save, mock.patch(
        'os.chdir'
    ) as mock_chdir, mock.patch('os.getcwd') as mock_getcwd:
      # The first save raises OSError. The test expects `from_pil_image` to
      # retry from /tmp and then switch back to the original directory.
      mock_save.side_effect = [OSError, None]
      mock_getcwd.return_value = '/curr/dir'
      image = image_lib.Image.from_pil_image(img)
      self.assertIsInstance(image, image_lib.Image)
      self.assertEqual(mock_save.call_count, 2)
      mock_save.assert_has_calls([
          mock.call(mock.ANY, format='PNG'),
          mock.call(mock.ANY, format='PNG'),
      ])
      mock_chdir.assert_has_calls([
          mock.call('/tmp'),
          mock.call('/curr/dir'),
      ])

  def test_gif_is_compatible(self):
    gif_image = self._make_gif_image()

    # GIF should be compatible if PNG is in supported types
    self.assertTrue(gif_image._is_compatible(['image/png']))
    self.assertTrue(gif_image._is_compatible(['image/jpeg', 'image/webp']))
    self.assertTrue(gif_image._is_compatible(['image/png', 'image/jpeg']))

    # GIF should not be compatible if only unsupported types
    self.assertFalse(gif_image._is_compatible(['video/mp4']))
    self.assertFalse(gif_image._is_compatible(['application/pdf']))

  def test_gif_make_compatible(self):
    gif_image = self._make_gif_image()

    # Test 1: Convert to PNG (first priority when available)
    converted = gif_image.make_compatible(['image/png', 'image/jpeg'])
    self.assertEqual(converted.mime_type, 'image/png')
    self.assertIsInstance(converted, image_lib.Image)

    # Test 2: Convert to JPEG when PNG not supported
    converted = gif_image.make_compatible(['image/jpeg', 'image/webp'])
    self.assertEqual(converted.mime_type, 'image/jpeg')

    # Test 3: Convert to WEBP when PNG and JPEG not supported
    converted = gif_image.make_compatible(['image/webp'])
    self.assertEqual(converted.mime_type, 'image/webp')

    # Test 4: Should raise error when no compatible format
    with self.assertRaises(lf.ModalityError):
      gif_image.make_compatible(['video/mp4'])

  def test_is_compatible_direct_match(self):
    image = image_lib.Image.from_bytes(image_content)  # image/png
    self.assertTrue(image._is_compatible(['image/png', 'image/jpeg']))
    self.assertTrue(image._is_compatible(['image/png']))
    self.assertFalse(image._is_compatible(['image/jpeg']))

  def test_make_compatible_no_conversion(self):
    image = image_lib.Image.from_bytes(image_content)  # image/png
    converted_image = image.make_compatible(['image/png', 'image/jpeg'])
    # An already-supported image is returned as the very same object.
    self.assertIs(image, converted_image)

  def test_convert_to_format_jpeg_transparency(self):
    # Create a simple 1x1 semi-transparent RGBA PNG image.
    buf = io.BytesIO()
    img = pil_image.new('RGBA', (1, 1), (255, 0, 0, 128))
    img.save(buf, format='PNG')
    rgba_png_bytes = buf.getvalue()

    rgba_image = image_lib.Image.from_bytes(rgba_png_bytes)
    self.assertEqual(rgba_image.mime_type, 'image/png')

    # JPEG has no alpha channel, so the converted image must come back as RGB.
    converted_image = rgba_image._convert_to_format('JPEG')
    self.assertEqual(converted_image.mime_type, 'image/jpeg')
    pil_img = converted_image.to_pil_image()
    self.assertEqual(pil_img.mode, 'RGB')

  def test_convert_to_format_os_error(self):
    image = image_lib.Image.from_bytes(image_content)
    mock_pil_image = mock.MagicMock()
    mock_save = mock_pil_image.save
    # First save fails; the retry (after chdir to /tmp) succeeds.
    mock_save.side_effect = [OSError, None]

    with mock.patch.object(
        image, 'to_pil_image', return_value=mock_pil_image
    ), mock.patch('os.chdir') as mock_chdir, mock.patch(
        'os.getcwd'
    ) as mock_getcwd:
      mock_getcwd.return_value = '/curr/dir'
      converted_image = image._convert_to_format('PNG')
      self.assertIsInstance(converted_image, image_lib.Image)
      self.assertEqual(mock_save.call_count, 2)
      mock_save.assert_has_calls([
          mock.call(mock.ANY, format='PNG'),
          mock.call(mock.ANY, format='PNG'),
      ])
      mock_chdir.assert_has_calls([
          mock.call('/tmp'),
          mock.call('/curr/dir'),
      ])


if __name__ == '__main__':
  # Run the test cases in this module when it is executed directly.
  unittest.main()
